import os
import numpy as np
from tqdm import tqdm
from PIL import Image
from einops import rearrange

import torch
import torchvision
from torch import Tensor
from torchvision.utils import make_grid
from torchvision.transforms.functional import to_tensor
from PIL import Image, ImageDraw, ImageFont


def save_video_tensor_to_mp4(video, path, fps):
    """Write a batch of videos, placed side by side, into one H.264 mp4 file.

    Args:
        video: tensor laid out as b,c,t,h,w (assumed from the layout used
            below); values are clamped to [-1, 1] before conversion.
        path: destination handed to ``torchvision.io.write_video``.
        fps: frame rate of the written video.
    """
    clip = torch.clamp(video.detach().cpu().float(), -1.0, 1.0)
    batch_size = clip.shape[0]
    per_frame = clip.permute(2, 0, 1, 3, 4)  # t,n,c,h,w

    # One grid per time step; nrow == batch size keeps all samples on one row.
    sheets = []
    for frame_batch in per_frame:
        sheets.append(
            torchvision.utils.make_grid(frame_batch, nrow=int(batch_size))
        )

    frames = torch.stack(sheets, dim=0)  # [t, c, grid_h, grid_w]
    frames = (frames + 1.0) / 2.0  # [-1, 1] -> [0, 1]
    frames = (frames * 255).to(torch.uint8).permute(0, 2, 3, 1)  # channels last
    torchvision.io.write_video(
        path, frames, fps=fps, video_codec="h264", options={"crf": "10"}
    )


def save_video_tensor_to_frames(video, dir):
    """Dump every frame of a single video as a JPEG file.

    Frames are written as ``frame000.jpg``, ``frame001.jpg``, ... inside
    ``dir``, which is created when it does not exist yet.

    Args:
        video: tensor laid out as b,c,t,h,w with a batch size of exactly 1.
            Values are clamped to [-1, 1] and mapped to the 0-255 range.
        dir: output directory.

    Raises:
        AssertionError: if the batch dimension is not 1.
    """
    os.makedirs(dir, exist_ok=True)
    video = torch.clamp(video.detach().cpu().float(), -1.0, 1.0)
    assert video.shape[0] == 1, "expected a single video (batch size 1)"

    frames = video[0].permute(1, 2, 3, 0)  # c,t,h,w -> t,h,w,c
    frames = ((frames + 1.0) / 2.0 * 255).to(torch.uint8).numpy()
    for index, frame in enumerate(frames):
        # Pillow's JPEG writer reads the ``quality`` option. The former ``q``
        # keyword is not an option it knows, so it was silently ignored and
        # the frames were saved at the default quality instead of 95.
        Image.fromarray(frame).save(
            os.path.join(dir, f"frame{index:03d}.jpg"), quality=95
        )


def frames_to_mp4(frame_dir, output_path, fps):
    """Assemble the image files of a directory into one H.264 mp4 file.

    Every entry of ``frame_dir`` is opened as an image, in sorted filename
    order, and becomes one frame of the video written to ``output_path``.
    """

    def load_frames(folder: os.PathLike, limit: int):
        # A falsy limit (None or 0) means "take every file".
        names = sorted(os.listdir(folder))
        if limit:
            names = names[:limit]
        opened = [Image.open(os.path.join(folder, name)) for name in names]
        return torch.stack([to_tensor(picture) for picture in opened])

    frames = load_frames(frame_dir, limit=None)
    # Scale to 0-255 and move the channel axis last for the video writer.
    frames = frames.mul(255).to(torch.uint8).permute(0, 2, 3, 1)
    torchvision.io.write_video(
        output_path, frames, fps=fps, video_codec="h264", options={"crf": "10"}
    )


def tensor_to_mp4(video, savepath, fps, rescale=True, nrow=None):
    """Tile a batch of videos into a grid and save it as an H.264 mp4 file.

    Args:
        video: torch.Tensor laid out as b,c,t,h,w. Values are expected in
            [-1, 1] when ``rescale`` is True and in [0, 1] otherwise.
        savepath: destination handed to ``torchvision.io.write_video``.
        fps: frame rate of the written video.
        rescale: map [-1, 1] to [0, 1] before converting to uint8.
        nrow: number of videos per grid row; defaults to ``int(sqrt(b))``.
    """
    # Detach and move to CPU up front, as the sibling helpers in this file do,
    # so the frames can be handed to the video writer as plain CPU data.
    video = video.detach().cpu()
    n = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    nrow = int(np.sqrt(n)) if nrow is None else nrow
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
    ]  # each [3, grid_h, grid_w]
    grid = torch.stack(frame_grids, dim=0).float()  # [T, 3, grid_h, grid_w]
    if rescale:
        grid = (torch.clamp(grid, -1.0, 1.0) + 1.0) / 2.0
    else:
        # Without rescaling the data is already meant to be in [0, 1]. The
        # former clamp to [-1, 1] let negative values reach the uint8 cast,
        # where they do not map to a meaningful pixel value.
        grid = torch.clamp(grid, 0.0, 1.0)
    grid = (
        (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    )  # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]
    torchvision.io.write_video(
        savepath, grid, fps=fps, video_codec="h264", options={"crf": "10"}
    )


def tensor2videogrids(video, root, filename, fps, rescale=True, clamp=True):
    """Tile a batch of videos into a near-square grid and save it as mp4.

    Args:
        video: torch.Tensor laid out as b,c,t,h,w.
        root: output directory; created when it does not exist yet.
        filename: name of the video file inside ``root``.
        fps: frame rate of the written video.
        rescale: map [-1, 1] to [0, 1] before converting to uint8.
        clamp: clamp the input to [-1, 1] first. When False the values are
            used as they are, so the caller must keep them in range.

    Raises:
        AssertionError: if ``video`` is not a 5-D torch.Tensor.
    """
    # Check the type before the rank: calling .dim() on a non-tensor would
    # raise AttributeError before the isinstance assertion could report it.
    assert isinstance(video, torch.Tensor)
    assert video.dim() == 5  # b,c,t,h,w

    # Cast to float like the sibling helpers, so the clamp and the scaling
    # below always run in float32 regardless of the input dtype.
    video = video.detach().cpu().float()
    if clamp:
        video = torch.clamp(video, -1.0, 1.0)
    n = video.shape[0]
    video = video.permute(2, 0, 1, 3, 4)  # t,n,c,h,w
    nrow = int(np.sqrt(n))
    frame_grids = [
        torchvision.utils.make_grid(framesheet, nrow=nrow) for framesheet in video
    ]  # each [3, grid_h, grid_w]
    grid = torch.stack(frame_grids, dim=0)  # [T, 3, grid_h, grid_w]
    if rescale:
        grid = (grid + 1.0) / 2.0
    grid = (
        (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
    )  # [T, 3, grid_h, grid_w] -> [T, grid_h, grid_w, 3]

    path = os.path.join(root, filename)
    # Make sure the target directory exists before the writer opens the file.
    out_dir = os.path.dirname(path)
    if out_dir != "":
        os.makedirs(out_dir, exist_ok=True)
    torchvision.io.write_video(
        path, grid, fps=fps, video_codec="h264", options={"crf": "10"}
    )


def log_txt_as_img(wh, xc, size=10):
    """Render each caption as black text on a white RGB image.

    Args:
        wh: tuple of (width, height) used for every rendered image.
        xc: list of caption strings to plot.
        size: font size passed to ``ImageFont.truetype``.

    Returns:
        Tensor of shape (len(xc), 3, height, width) with values in [-1, 1].

    Raises:
        ValueError: if ``xc`` is empty (there is nothing to stack).
    """
    # The font and the wrap width do not depend on the caption, so they are
    # set up once instead of once per caption.
    font = ImageFont.truetype("data/DejaVuSans.ttf", size=size)
    # About 40 characters per 256 pixels of width. Kept at 1 or more so the
    # range() step below can never be zero for very narrow images.
    nc = max(1, int(40 * (wh[0] / 256)))

    txts = []
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        # Hard-wrap the caption every nc characters.
        lines = "\n".join(
            caption[start : start + nc] for start in range(0, len(caption), nc)
        )

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            # Best effort: keep the blank image for captions that cannot be drawn.
            print("Cant encode string for logging. Skipping.")

        # h,w,c in 0-255 -> c,h,w in [-1, 1]
        txts.append(np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0)
    return torch.tensor(np.stack(txts))


def log_local(batch_logs, save_dir, filename, save_fps=10, rescale=True):
    """Save the captions, videos and images of a log dict to ``save_dir``.

    Each entry is handled according to its value:
      * non-empty list whose first element is a str: written to
        ``<key>-<filename>.txt``, one ``idx=<i>, txt=<caption>`` line each;
      * 5-D tensor (b,c,t,h,w) with 1 or 3 channels: written to
        ``<key>-<filename>.mp4`` with the samples stacked vertically;
      * 4-D tensor (b,c,h,w) with 1 or 3 channels: written to
        ``<key>-<filename>.jpg`` with the samples stacked vertically;
      * anything else (including empty lists) is skipped.

    Args:
        batch_logs: dict of log entries, or None to do nothing.
        save_dir: output directory; created when it does not exist yet.
        filename: suffix shared by all files written for this call.
        save_fps: frame rate of written videos.
        rescale: map tensor values from [-1, 1] to [0, 1] before saving.

    Returns:
        None.
    """
    if batch_logs is None:
        return None

    def save_img_grid(grid, path, rescale):
        if rescale:
            grid = (grid + 1.0) / 2.0  # -1,1 -> 0,1; c,h,w
        grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)  # -> h,w,c
        grid = grid.numpy()
        grid = (grid * 255).astype(np.uint8)
        parent = os.path.split(path)[0]
        if parent:
            os.makedirs(parent, exist_ok=True)
        Image.fromarray(grid).save(path)

    # Captions and videos are written straight into save_dir, so it has to
    # exist before the first file is opened (an empty string means the cwd).
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for key, value in batch_logs.items():
        # ``value and`` guards the value[0] lookup against empty lists.
        if isinstance(value, list) and value and isinstance(value[0], str):
            ## a batch of captions
            path = os.path.join(save_dir, "%s-%s.txt" % (key, filename))
            with open(path, "w") as f:
                f.writelines(f"idx={i}, txt={txt}\n" for i, txt in enumerate(value))
        elif isinstance(value, torch.Tensor) and value.dim() == 5:
            ## save video grids; only grayscale or rgb mode
            if value.shape[1] not in (1, 3):
                continue
            video = value.permute(2, 0, 1, 3, 4)  # b,c,t,h,w -> t,n,c,h,w
            frame_grids = [
                torchvision.utils.make_grid(framesheet, nrow=1)
                for framesheet in video
            ]  # each [3, n*h, 1*w]
            grid = torch.stack(frame_grids, dim=0)  # [t, 3, n*h, w]
            if rescale:
                grid = (grid + 1.0) / 2.0
            grid = (grid * 255).to(torch.uint8).permute(0, 2, 3, 1)
            path = os.path.join(save_dir, "%s-%s.mp4" % (key, filename))
            torchvision.io.write_video(
                path, grid, fps=save_fps, video_codec="h264", options={"crf": "10"}
            )
            # NOTE: the per-video frame sheet (jpg) export is disabled. Its
            # grid used to be computed here and then discarded, so that dead
            # computation has been dropped.
        elif isinstance(value, torch.Tensor) and value.dim() == 4:
            ## save image grids; only grayscale or rgb mode
            if value.shape[1] not in (1, 3):
                continue
            grid = torchvision.utils.make_grid(value, nrow=1)
            path = os.path.join(save_dir, "%s-%s.jpg" % (key, filename))
            save_img_grid(grid, path, rescale)


def prepare_to_log(batch_logs, max_images=100000, clamp=True):
    """Trim and normalise the entries of a log dict in place.

    Every non-None entry is cut to at most ``max_images`` leading items.
    Tensor entries are additionally detached, moved to CPU and, when
    ``clamp`` is set, cast to float and clamped to [-1, 1].

    Returns:
        The same dict object that was passed in, or None for a None input.
    """
    if batch_logs is None:
        return None

    for name in batch_logs:
        entry = batch_logs[name]
        if entry is None:
            continue

        # Tensors and arrays expose their batch size via shape, lists via len.
        count = entry.shape[0] if hasattr(entry, "shape") else len(entry)
        entry = entry[: min(count, max_images)]
        batch_logs[name] = entry

        ## in batch_logs: images <batched tensor> & caption <text list>
        if not isinstance(entry, torch.Tensor):
            continue

        entry = entry.detach().cpu()
        batch_logs[name] = entry
        if clamp:
            try:
                batch_logs[name] = torch.clamp(entry.float(), -1.0, 1.0)
            except RuntimeError:
                print("clamp_scalar_cpu not implemented for Half")
    return batch_logs


# ----------------------------------------------------------------------------------------------


def fill_with_black_squares(video, desired_len: int) -> Tensor:
    """Pad a video with all-zero frames until it has ``desired_len`` frames.

    Args:
        video: tensor whose first dimension indexes frames; any number of
            trailing dimensions is supported.
        desired_len: minimum number of frames of the result.

    Returns:
        ``video`` itself when it already has at least ``desired_len`` frames,
        otherwise a new tensor with zero frames appended at the end.
    """
    missing = desired_len - len(video)
    if missing <= 0:
        return video

    # Build the padding from the frame shape instead of from video[0], so this
    # works for any rank and for a clip that has no frames at all.
    padding = video.new_zeros((missing, *video.shape[1:]))
    return torch.cat([video, padding], dim=0)


# ----------------------------------------------------------------------------------------------
def load_num_videos(data_path, num_videos):
    """Return the first ``num_videos`` videos of an npz archive or an array.

    Args:
        data_path: path (str or path-like) of an ``.npz`` file that stores the
            videos under the key ``arr_0``, or an already loaded np.ndarray.
        num_videos: number of leading videos to keep; None keeps all of them.

    Returns:
        np.ndarray of videos (NTHWC per the convention used in this file).

    Raises:
        TypeError: if ``data_path`` is neither a path nor an np.ndarray.
    """
    # first argument can be either a data path or a np array
    if isinstance(data_path, (str, os.PathLike)):
        # The context manager closes the archive once the array has been read.
        with np.load(data_path) as archive:
            videos = archive["arr_0"]  # NTHWC
    elif isinstance(data_path, np.ndarray):
        videos = data_path
    else:
        raise TypeError(
            f"data_path must be a path or a numpy array, got {type(data_path).__name__}"
        )

    if num_videos is not None:
        videos = videos[:num_videos]
    return videos


def npz_to_video_grid(
    data_path, out_path, num_frames, fps, num_videos=None, nrow=None, verbose=True
):
    """Tile the videos of an npz archive (or array) into one mp4 grid video.

    Args:
        data_path: path of an ``.npz`` file or an np.ndarray of videos laid
            out as NTHWC.
        out_path: destination video file; its directory is created if needed.
        num_frames: videos shorter than this are padded with black frames.
        fps: frame rate of the written video.
        num_videos: keep only this many leading videos; None keeps all.
        nrow: number of videos per grid row; defaults to ``ceil(sqrt(N))``.
        verbose: show tqdm progress bars while padding and tiling.
    """
    # load_num_videos accepts both a path and an in-memory array and applies
    # num_videos to either. The limit used to be silently ignored for arrays.
    videos = load_num_videos(data_path, num_videos)
    n, t, _, _, _ = videos.shape  # also enforces the 5-D NTHWC layout

    # Convert every video frame by frame: T x (H, W, C) array -> [T, C, H, W].
    videos_th = []
    for clip in videos:
        videos_th.append(torch.stack([to_tensor(frame) for frame in clip]))

    clips = tqdm(videos_th, desc="Adding empty frames") if verbose else videos_th
    videos = [fill_with_black_squares(v, num_frames) for v in clips]  # NTCHW

    frame_grids = torch.stack(videos).permute(1, 0, 2, 3, 4)  # [T, N, C, H, W]
    if nrow is None:
        nrow = int(np.ceil(np.sqrt(n)))
    sheets = tqdm(frame_grids, desc="Making grids") if verbose else frame_grids
    frame_grids = [make_grid(fs, nrow=nrow) for fs in sheets]

    if os.path.dirname(out_path) != "":
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
    frame_grids = (
        (torch.stack(frame_grids) * 255).to(torch.uint8).permute(0, 2, 3, 1)
    )  # [T, H, W, C]
    torchvision.io.write_video(
        out_path, frame_grids, fps=fps, video_codec="h264", options={"crf": "10"}
    )
